#include <torch/extension.h>
#include <vector>

// Derivative of the sigmoid: s'(z) = (1 - s(z)) * s(z)
torch::Tensor d_sigmoid(torch::Tensor z) {
  auto s = torch::sigmoid(z);
  return (1 - s) * s;
}

// Derivative of tanh: tanh'(z) = 1 - tanh^2(z)
torch::Tensor d_tanh(torch::Tensor z) {
  return 1 - z.tanh().pow(2);
}

// Derivative of ELU: elu'(z) = 1 if z > 0, else alpha * exp(z).
// Computed below as (z > 0) + [alpha * (exp(z) - 1) < 0] * alpha * exp(z),
// since alpha * (exp(z) - 1) < 0 exactly where z < 0.
torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}
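
// Sanity-check sketch for the derivatives above (a hypothetical helper,
// not required by the extension): each analytic derivative should agree
// with what autograd computes for the corresponding activation.
bool d_sigmoid_matches_autograd() {
  auto z = torch::randn({8}, torch::requires_grad());
  torch::sigmoid(z).sum().backward();
  return torch::allclose(z.grad(), d_sigmoid(z.detach()));
}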


std::vector<torch::Tensor> lltm_forward(
    torch::Tensor input,        // (batch_size, input_features)
    torch::Tensor weights,      // (3*state_size, input_features + state_size)
    torch::Tensor bias,         // (3*state_size), broadcast across the batch
    torch::Tensor old_h,        // (batch_size, state_size)
    torch::Tensor old_cell) {   // (batch_size, state_size)
    // Concatenate along dim=1: (batch_size, state_size) and (batch_size, input_features) -> (batch_size, state_size + input_features)
    auto X = torch::cat({old_h, input}, /*dim=*/1);

    // addmm computes beta * M + alpha * (m1 @ m2), where beta and alpha default to 1 and @ is matrix multiplication:
    // (3*state_size) + (batch_size, state_size + input_features) @ (state_size + input_features, 3*state_size) -> (batch_size, 3*state_size)
    auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
    // Split (batch_size, 3*state_size) into three (batch_size, state_size) chunks
    auto gates = gate_weights.chunk(3, /*dim=*/1);

    auto input_gate = torch::sigmoid(gates[0]);
    auto output_gate = torch::sigmoid(gates[1]);
    // Pass the candidate values through the ELU activation
    auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);

    // Unlike an LSTM there is no forget gate: the gated candidate is simply added to the old cell state, elementwise (batch_size, state_size) * (batch_size, state_size)
    auto new_cell = old_cell + candidate_cell * input_gate;
    // (batch_size, state_size)
    auto new_h = torch::tanh(new_cell) * output_gate;

    // The intermediates are returned alongside new_h and new_cell so the
    // caller can save them for the backward pass.
    return {new_h,
            new_cell,
            input_gate,
            output_gate,
            candidate_cell,
            X,
            gate_weights};
}
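
// Minimal shape check for lltm_forward (a hypothetical helper; the sizes
// are illustrative assumptions): new_h should come back as
// (batch_size, state_size).
std::vector<int64_t> lltm_forward_output_shape() {
  const int64_t batch_size = 4, input_features = 8, state_size = 16;
  auto outputs = lltm_forward(
      torch::randn({batch_size, input_features}),
      torch::randn({3 * state_size, input_features + state_size}),
      torch::randn({3 * state_size}),
      torch::randn({batch_size, state_size}),
      torch::randn({batch_size, state_size}));
  return outputs[0].sizes().vec();  // expected: {batch_size, state_size}
}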

std::vector<torch::Tensor> lltm_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
    // new_h = tanh(new_cell) * output_gate
    auto d_output_gate = torch::tanh(new_cell) * grad_h;
    auto d_tanh_new_cell = output_gate * grad_h;
    auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

    // new_cell = old_cell + candidate_cell * input_gate
    auto d_old_cell = d_new_cell;
    auto d_candidate_cell = input_gate * d_new_cell;
    auto d_input_gate = candidate_cell * d_new_cell;

    // Backpropagate through the gate activations to the pre-activation values
    auto gates = gate_weights.chunk(3, /*dim=*/1);
    d_input_gate *= d_sigmoid(gates[0]);
    d_output_gate *= d_sigmoid(gates[1]);
    d_candidate_cell *= d_elu(gates[2]);

    // Reassemble into (batch_size, 3*state_size), matching gate_weights
    auto d_gates =
        torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

    // (3*state_size, state_size + input_features), matching `weights`
    auto d_weights = d_gates.t().mm(X);
    // Sum over the batch dimension for the bias gradient
    auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

    auto d_X = d_gates.mm(weights);
    const auto state_size = grad_h.size(1);
    // Split d_X back into the gradients w.r.t. old_h and input
    auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
    auto d_input = d_X.slice(/*dim=*/1, state_size);

    return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
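
// Note: gradients come back as {d_old_h, d_input, d_weights, d_bias, d_old_cell};
// a Python autograd.Function wrapper is expected to reorder them to match its
// forward arguments. Running torch.autograd.gradcheck against such a wrapper
// (with double-precision inputs) is a convenient way to verify this backward.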

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward");
  m.def("backward", &lltm_backward, "LLTM backward");
}
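
// Building and calling the extension from Python (a sketch; the file name
// "lltm.cpp" is an assumption):
//
//   from torch.utils.cpp_extension import load
//   lltm = load(name="lltm", sources=["lltm.cpp"])
//   new_h, new_cell, *saved = lltm.forward(input, weights, bias, old_h, old_cell)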